#include "Lazy.hpp"

/**
 * Stores the tuning parameters verbatim; no validation is performed here.
 *
 * @param r        number of elements votingselect draws into its random sample
 * @param k        offset from the sample midpoint used to pick the pivots
 * @param s        stored as-is; not read by any function in this file
 * @param epsilon  probability with which WeakOracle reports the wrong answer
 */
lazySelect::lazySelect(int r, int k, int s, double epsilon)
    : r(r),
      k(k),
      s(s),
      epsilon(epsilon) {
}

/* i.i.d. weak oracle (disabled): every query is flipped independently with probability epsilon. */
// bool lazySelect::WeakOracle(int a, int b){
//     static std::mt19937 gen(std::random_device{}());
//     std::bernoulli_distribution dist(epsilon);
//     return dist(gen) ? !(a < b) : a < b;
// }

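/* Persistent weak oracle: the answer for a pair of sample elements is drawn once, cached, and repeated on later queries. */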
/**
 * Resets the persistent-oracle state for a new sample.
 *
 * Copies `sample` into `current_sample`, zeroes a bit cache holding two bits
 * for every (i, j) index pair of the sample, and rebuilds the value -> index
 * lookup used by WeakOracle. If a value occurs more than once in `sample`,
 * the lookup keeps the index of its last occurrence.
 *
 * @param sample  the elements whose pairwise oracle answers should be cached
 */
void lazySelect::InitWeakOracleCache(const std::vector<int>& sample) {
    current_sample = sample;
    sample_index_map.clear();

    // Two bits per pair, packed into 64-bit words (rounded up).
    const size_t count = sample.size();
    const size_t bits_needed = 2ULL * count * count;
    const size_t words_needed = (bits_needed + 63) / 64;
    oracle_cache.assign(words_needed, 0);

    size_t position = 0;
    for (const int value : sample) {
        sample_index_map[value] = position;
        ++position;
    }
}
/**
 * Noisy comparison: answers "is a < b?" and is wrong with probability epsilon.
 *
 * - If both values belong to the current sample, the decision whether to lie
 *   is drawn once per unordered pair and stored in two cache bits
 *   (bit 0: pair already decided, bit 1: answer is flipped). Later queries on
 *   the same pair, in either argument order, give a consistent answer.
 * - Otherwise every call draws a fresh, independent flip.
 *
 * @return the (possibly flipped) result of a < b; always false when a and b
 *         map to the same sample index.
 */
bool lazySelect::WeakOracle(int a, int b) {
    // One generator per thread, seeded once; kept static purely for speed.
    static thread_local std::mt19937 gen(std::random_device{}());
    std::bernoulli_distribution dist(epsilon);

    const auto pos_a = sample_index_map.find(a);
    const auto pos_b = sample_index_map.find(b);

    // At least one value is outside the sample: no caching, independent noise.
    if (pos_a == sample_index_map.end() || pos_b == sample_index_map.end()) {
        const bool truth = (a < b);
        return dist(gen) ? !truth : truth;
    }

    size_t lo = pos_a->second;
    size_t hi = pos_b->second;
    if (lo == hi) {
        return false;
    }

    // Canonicalise the pair so that (a, b) and (b, a) share one cache slot.
    const bool swapped = lo > hi;
    if (swapped) {
        std::swap(lo, hi);
    }

    const size_t n = current_sample.size();
    const size_t bit_pos = 2 * (lo * n + hi);
    const size_t word = bit_pos / 64;
    const size_t shift = bit_pos % 64;
    assert(word < oracle_cache.size());

    const uint64_t state = (oracle_cache[word] >> shift) & 3;
    const bool truth = current_sample[lo] < current_sample[hi];

    bool lie;
    if (state & 1) {
        // Already decided earlier: replay the stored flip bit.
        lie = (state >> 1) & 1;
    } else {
        // First query for this pair: draw the flip and remember it.
        lie = epsilon > 0.0 ? dist(gen) : false;
        oracle_cache[word] |= (uint64_t(1) << shift);
        if (lie) {
            oracle_cache[word] |= (uint64_t(1) << (shift + 1));
        }
    }

    const bool answer = lie ? !truth : truth;
    // The cached answer is for the canonical order; invert it if we swapped.
    return swapped ? !answer : answer;
}

/**
 * One round of sampling-based selection using weak-oracle voting.
 *
 * Steps:
 *  1. Draw up to `r` elements of `A` at random and sort them exactly.
 *  2. Take pivots `d` and `u` at `k` positions below/above the sample midpoint.
 *  3. For each element of `A`, ask the weak oracle about a series of pivot
 *     pairs moving outwards from (d, u). A > 2/3 majority classifies the
 *     element as left or right of the band; otherwise one exact comparison
 *     against `d` / `u` decides between left, right, and the middle set `M`.
 *  4. If the target rank falls inside `M`, sort `M` exactly and return the
 *     element at 0-based position A.size()/2 - 1 of the overall order
 *     (assuming the vote-based classification was correct).
 *
 * `weak_cnt` and `strong_cnt` are reset on entry and count weak-oracle calls
 * and exact comparisons respectively.
 *
 * @param A         input values (not modified)
 * @param constant  multiplier for the number of voting rounds,
 *                  which is the count of j >= 0 with j < constant * log2(A.size())
 * @return the selected element, or -1 when the round fails (target rank not
 *         inside `M`, `A` empty, or `r <= 0`). NOTE: -1 is a sentinel, so a
 *         legitimately selected value of -1 is indistinguishable from failure.
 */
int lazySelect::votingselect(std::vector<int> &A, int constant){
    weak_cnt = 0;
    strong_cnt = 0;

    // Nothing can be sampled: fail cleanly instead of indexing an empty sample.
    if (A.empty() || r <= 0) {
        return -1;
    }

    std::vector<int> sample;
    std::sample(A.begin(), A.end(), std::back_inserter(sample), r, std::mt19937{std::random_device{}()});
    std::sort(sample.begin(), sample.end(), [&](int element_a, int element_b) {
        strong_cnt++;
        return element_a < element_b;
    });

    /*persistent weak oracle*/
    InitWeakOracleCache(sample);

    // std::sample yields min(r, A.size()) elements, so all pivot indices are
    // derived from and clamped to the real sample size rather than to r.
    const int sample_len = static_cast<int>(sample.size());
    const int mid = sample_len / 2;
    const auto clamp_to_sample = [sample_len](int idx) {
        return std::min(sample_len - 1, std::max(0, idx));
    };

    d = sample[clamp_to_sample(mid - k)];
    u = sample[clamp_to_sample(mid + k)];

    // The pivots used in each voting round do not depend on the element being
    // classified, so compute them (and the round count) once.
    const double round_limit = constant * std::log2(A.size());
    std::vector<int> low_pivots;
    std::vector<int> high_pivots;
    for (int j = 0; j < round_limit; j++) {
        low_pivots.push_back(sample[clamp_to_sample(mid - k - j)]);
        high_pivots.push_back(sample[clamp_to_sample(mid + k + j)]);
    }

    std::vector<int> M;
    int L = 0;  // number of elements classified as lying left of the band
    for (int a : A) {
        int count_l = 0, count_r = 0;
        for (std::size_t j = 0; j < low_pivots.size(); ++j) {
            weak_cnt += 2;
            if (WeakOracle(a, low_pivots[j])) {
                count_l++;
            }
            else {
                count_r++;
            }
            if (WeakOracle(a, high_pivots[j])) {
                count_l++;
            }
            else {
                count_r++;
            }
        }

        const int total = count_l + count_r;
        if (count_l > (2.0/3.0) * total) {
            L++;
        }
        else if (count_r > (2.0/3.0) * total) {
            // Confidently right of the band: contributes to neither L nor M.
        }
        else {
            // Votes were inconclusive: fall back to an exact comparison.
            strong_cnt++;
            if (a < d) {
                L++;
            }
            else if (!(a > u)) {
                M.push_back(a);
            }
        }
    }

    const std::size_t half = A.size() / 2;
    const std::size_t left = static_cast<std::size_t>(L);
    if (left < half && left + M.size() > half) {
        std::sort(M.begin(), M.end(), [&](int element_a, int element_b) {
            strong_cnt++;
            return element_a < element_b;
        });
        return M[half - left - 1];
    }
    return -1;
}

int64_t lazySelect::GetWeakCnt(){
    return weak_cnt;
}
int64_t lazySelect::GetStrongCnt(){
    return strong_cnt;
}